{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# ERNIE模型初尝试\n",
    "\n",
    "## 项目介绍\n",
    "\n",
    "【项目的最初目的】：使用飞桨已有的官方套件，初步实现一个相似问题（严谨实用向）回复功能。\n",
    "\n",
    "**【项目最终实现】：使用飞桨已有的官方套件，初步实现一个（额，文本生成，无聊向的）回复功能，看一下通过对烦恼类的简短的问题+回复的学习，最终能自动生成什么回复。**\n",
    "\n",
    "\n",
    "# 数据介绍\n",
    "选择飞桨自带的baike_qa2019 百科类问答json版数据集（链接：https://aistudio.baidu.com/aistudio/competition/detail/63/0/task-definition）\n",
    "数据集为压缩文件，解压后有2个子文件。训练数据集baike_qa_train.json和测试数据集baike_qa_valid.json。\n",
    "\n",
    "每条训练数据是一个字典形式，key有5个：qid问答id，category问答划分的类别(如教育/科学-理工学科-地球科学)，title问题，desc问题描述，answer回答，如下所示：\n",
    "{'qid': 'qid_5982723620932473219', 'category': '教育/科学-理工学科-地球科学', 'title': '人站在地球上为什么没有头朝下的感觉 ', 'desc': '', 'answer': '地球上重力作用一直是指向球心的，因此\\r\\n只要头远离球心，人们就回感到头朝上。'} \n",
    "\n",
    "\n",
    "***----因为模型选择已偏航(文本生成，认为返回的答案并不严谨)，就选择了category是烦恼的类别，问题和回复都很简短的（认为简短的相对不那么严肃）问题和答案进行学习***\n",
    "\n",
    "\n",
    "\n",
    "# 模型介绍\n",
    "计划：通过计算文本相似性来进行问题匹配，返回和输入最大相似的问题的回复\n",
    "\n",
    "由于：目前对飞桨了解不够充分，没有找到合适的模型。通过在项目集搜索，找到了使用paddlehub的ERNIE-GEN进行问答系统搭建的项目。\n",
    "\n",
    "ERNIE的github网址为 https://github.com/PaddlePaddle/ERNIE ，里面介绍：ERNIE在工业界得到了大规模应用，如搜索引擎、新闻推荐、广告系统、语音交互、智能客服等。\n",
    "而2020.5开源的 ERNIE-GEN 模型是最强文本生成预训练模型。\n",
    "\n",
    "所以ernie-gen可能更偏向于文本生成，非返回原有的资源库结果。但是目前没有发现适合的其他官方套件，所以想用ernie-gen尝试一下。"
   ]
  },
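  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The original plan described above (return the reply of the stored question most similar to the input) can be sketched in a few lines. This is only an illustration under stated assumptions: it uses character-level similarity from Python's standard `difflib` instead of any PaddlePaddle model, the helper name `most_similar_reply` is hypothetical, and the three question/reply pairs are copied from the training samples printed later in this notebook.\n",
    "\n",
    "```python\n",
    "from difflib import SequenceMatcher\n",
    "\n",
    "# A tiny question -> reply store (pairs copied from this notebook's training samples).\n",
    "QA_PAIRS = [\n",
    "    ('婆媳关系该怎样处理', '少见面，见面时投其所好，少告状。'),\n",
    "    ('爱情的力量有多大？', '拥有了真爱就像拥有了全世界'),\n",
    "    ('你最爱的人是谁？', '是我自己'),\n",
    "]\n",
    "\n",
    "def similarity(a, b):\n",
    "    # Character-level similarity in [0, 1]; a stand-in for a real semantic model.\n",
    "    return SequenceMatcher(None, a, b).ratio()\n",
    "\n",
    "def most_similar_reply(query, qa_pairs):\n",
    "    # Score every stored question and return (question, reply, score) of the best match.\n",
    "    question, reply = max(qa_pairs, key=lambda qa: similarity(query, qa[0]))\n",
    "    return question, reply, similarity(query, question)\n",
    "\n",
    "question, reply, score = most_similar_reply('婆媳关系怎么处理', QA_PAIRS)\n",
    "print(question, reply, round(score, 3))\n",
    "```\n",
    "\n",
    "A practical system would more likely compare sentence embeddings than raw characters, but the retrieval step (score every stored question, return the reply of the best match) would keep the same shape."
   ]
  },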
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 安装依赖包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Collecting paddlehub==2.2.0\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/1e/86/7184a1c76a3ca9bb47cbf838df6479164c21849da751452b5d11eea4140d/paddlehub-2.2.0-py3-none-any.whl (212 kB)\n",
      "     |████████████████████████████████| 212 kB 1.4 MB/s            \n",
      "\u001b[?25hRequirement already satisfied: packaging in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (21.3)\n",
      "Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (2.2.3)\n",
      "Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (4.1.1.26)\n",
      "Requirement already satisfied: colorlog in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (4.1.0)\n",
      "Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (5.1.2)\n",
      "Requirement already satisfied: rarfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (3.1)\n",
      "Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (4.27.0)\n",
      "Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (1.19.5)\n",
      "Collecting paddle2onnx>=0.5.1\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/91/b9/c87d6c35f946a965b72ebccb078ec9e43016749f4aa3cd5eb9297c8199b3/paddle2onnx-0.9.1-py3-none-any.whl (95 kB)\n",
      "     |████████████████████████████████| 95 kB 1.5 MB/s             \n",
      "\u001b[?25hRequirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (8.2.0)\n",
      "Requirement already satisfied: flask>=1.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (1.1.1)\n",
      "Requirement already satisfied: pyzmq in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (22.3.0)\n",
      "Requirement already satisfied: paddlenlp>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (2.1.1)\n",
      "Requirement already satisfied: gunicorn>=19.10.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (20.0.4)\n",
      "Requirement already satisfied: filelock in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (3.0.12)\n",
      "Requirement already satisfied: colorama in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (0.4.4)\n",
      "Requirement already satisfied: visualdl>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (2.2.0)\n",
      "Requirement already satisfied: easydict in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==2.2.0) (1.9)\n",
      "Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.0->paddlehub==2.2.0) (0.16.0)\n",
      "Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.0->paddlehub==2.2.0) (1.1.0)\n",
      "Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.0->paddlehub==2.2.0) (2.11.0)\n",
      "Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.0->paddlehub==2.2.0) (7.0)\n",
      "Requirement already satisfied: setuptools>=3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gunicorn>=19.10.0->paddlehub==2.2.0) (56.2.0)\n",
      "Requirement already satisfied: protobuf in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddle2onnx>=0.5.1->paddlehub==2.2.0) (3.14.0)\n",
      "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddle2onnx>=0.5.1->paddlehub==2.2.0) (1.16.0)\n",
      "Collecting onnx<=1.9.0\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3f/9b/54c950d3256e27f970a83cd0504efb183a24312702deed0179453316dbd0/onnx-1.9.0-cp37-cp37m-manylinux2010_x86_64.whl (12.2 MB)\n",
      "     |████████████████████████████████| 12.2 MB 846 kB/s            \n",
      "\u001b[?25hRequirement already satisfied: paddlefsl==1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.0.0->paddlehub==2.2.0) (1.0.0)\n",
      "Requirement already satisfied: multiprocess in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.0.0->paddlehub==2.2.0) (0.70.11.1)\n",
      "Requirement already satisfied: h5py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.0.0->paddlehub==2.2.0) (2.9.0)\n",
      "Requirement already satisfied: jieba in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.0.0->paddlehub==2.2.0) (0.42.1)\n",
      "Requirement already satisfied: seqeval in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlenlp>=2.0.0->paddlehub==2.2.0) (1.2.2)\n",
      "Requirement already satisfied: requests~=2.24.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlefsl==1.0.0->paddlenlp>=2.0.0->paddlehub==2.2.0) (2.24.0)\n",
      "Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0->paddlehub==2.2.0) (1.1.5)\n",
      "Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0->paddlehub==2.2.0) (1.21.0)\n",
      "Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0->paddlehub==2.2.0) (1.0.0)\n",
      "Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0->paddlehub==2.2.0) (0.7.1.1)\n",
      "Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0->paddlehub==2.2.0) (4.0.1)\n",
      "Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.0.0->paddlehub==2.2.0) (0.8.53)\n",
      "Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->paddlehub==2.2.0) (2.8.2)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->paddlehub==2.2.0) (1.1.0)\n",
      "Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->paddlehub==2.2.0) (2019.3)\n",
      "Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->paddlehub==2.2.0) (0.10.0)\n",
      "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->paddlehub==2.2.0) (3.0.7)\n",
      "Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0->paddlehub==2.2.0) (0.6.1)\n",
      "Requirement already satisfied: pyflakes<2.5.0,>=2.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0->paddlehub==2.2.0) (2.4.0)\n",
      "Requirement already satisfied: importlib-metadata<4.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0->paddlehub==2.2.0) (4.2.0)\n",
      "Requirement already satisfied: pycodestyle<2.9.0,>=2.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.0.0->paddlehub==2.2.0) (2.8.0)\n",
      "Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.0.0->paddlehub==2.2.0) (2.8.0)\n",
      "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.0->paddlehub==2.2.0) (2.0.1)\n",
      "Requirement already satisfied: typing-extensions>=3.6.2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from onnx<=1.9.0->paddle2onnx>=0.5.1->paddlehub==2.2.0) (4.0.1)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests~=2.24.0->paddlefsl==1.0.0->paddlenlp>=2.0.0->paddlehub==2.2.0) (2.8)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests~=2.24.0->paddlefsl==1.0.0->paddlenlp>=2.0.0->paddlehub==2.2.0) (3.0.4)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests~=2.24.0->paddlefsl==1.0.0->paddlenlp>=2.0.0->paddlehub==2.2.0) (2019.9.11)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests~=2.24.0->paddlefsl==1.0.0->paddlenlp>=2.0.0->paddlehub==2.2.0) (1.25.6)\n",
      "Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl>=2.0.0->paddlehub==2.2.0) (3.9.9)\n",
      "Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl>=2.0.0->paddlehub==2.2.0) (0.18.0)\n",
      "Requirement already satisfied: dill>=0.3.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from multiprocess->paddlenlp>=2.0.0->paddlehub==2.2.0) (0.3.3)\n",
      "Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0->paddlehub==2.2.0) (1.4.10)\n",
      "Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0->paddlehub==2.2.0) (1.3.4)\n",
      "Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0->paddlehub==2.2.0) (0.10.0)\n",
      "Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0->paddlehub==2.2.0) (2.0.1)\n",
      "Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0->paddlehub==2.2.0) (16.7.9)\n",
      "Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.0.0->paddlehub==2.2.0) (1.3.0)\n",
      "Requirement already satisfied: scikit-learn>=0.21.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from seqeval->paddlenlp>=2.0.0->paddlehub==2.2.0) (0.24.2)\n",
      "Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata<4.3->flake8>=3.7.9->visualdl>=2.0.0->paddlehub==2.2.0) (3.7.0)\n",
      "Requirement already satisfied: scipy>=0.19.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp>=2.0.0->paddlehub==2.2.0) (1.6.3)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp>=2.0.0->paddlehub==2.2.0) (2.1.0)\n",
      "Requirement already satisfied: joblib>=0.11 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval->paddlenlp>=2.0.0->paddlehub==2.2.0) (0.14.1)\n",
      "Installing collected packages: onnx, paddle2onnx, paddlehub\n",
      "  Attempting uninstall: paddlehub\n",
      "    Found existing installation: paddlehub 2.0.4\n",
      "    Uninstalling paddlehub-2.0.4:\n",
      "      Successfully uninstalled paddlehub-2.0.4\n",
      "Successfully installed onnx-1.9.0 paddle2onnx-0.9.1 paddlehub-2.2.0\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.0.3 is available.\n",
      "You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.\u001b[0m\n",
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Collecting paddle-ernie\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d4/09/027ad18a04a43737d74c2e158952e7c268c84ae9d2e60ce32a258100c9b3/paddle-ernie-0.2.0.dev1.tar.gz (33 kB)\n",
      "  Preparing metadata (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25hRequirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddle-ernie) (2.24.0)\n",
      "Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddle-ernie) (4.27.0)\n",
      "Collecting pathlib2\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/09/eb/4af4bcd5b8731366b676192675221c5324394a580dfae469d498313b5c4a/pathlib2-2.3.7.post1-py2.py3-none-any.whl (18 kB)\n",
      "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pathlib2->paddle-ernie) (1.16.0)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->paddle-ernie) (2.8)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->paddle-ernie) (3.0.4)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->paddle-ernie) (1.25.6)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->paddle-ernie) (2019.9.11)\n",
      "Building wheels for collected packages: paddle-ernie\n",
      "  Building wheel for paddle-ernie (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for paddle-ernie: filename=paddle_ernie-0.2.0.dev1-py3-none-any.whl size=25548 sha256=e0e0613d3a581623b0829ffe5a5f567d33885486b500b8b6fc61dcdb7c31caf9\n",
      "  Stored in directory: /home/aistudio/.cache/pip/wheels/cf/b5/ea/1dc4bdb751cd206d020538bf525ff58d5a52fc4e0df4531123\n",
      "Successfully built paddle-ernie\n",
      "Installing collected packages: pathlib2, paddle-ernie\n",
      "Successfully installed paddle-ernie-0.2.0.dev1 pathlib2-2.3.7.post1\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.0.3 is available.\n",
      "You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install paddlehub==2.2.0\r\n",
    "!pip install paddle-ernie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 1.4 持久化安装, 需要使用持久化路径\n",
    "# !mkdir /home/aistudio/external-libraries\n",
    "# !pip install paddle-ernie -t /home/aistudio/external-libraries\n",
    "# !pip install paddlehub==2.1.1 -t /home/aistudio/external-libraries\n",
    "\n",
    "# 1.5 持久化安装后，添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: \n",
    "# import sys \n",
    "# sys.path.append('/home/aistudio/external-libraries')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 模型训练\n",
    "## 数据处理\n",
    "ERNIE-GEN模型要求的数据输入格式是：编号\\t文本1\\t文本2的格式，和原数据格式不符合，故先对数据进行处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "baike_qa_train.json  baike_qa_valid.json  百科类问答json版.zip\r\n"
     ]
    }
   ],
   "source": [
    "# 1.1 查看工作区文件, 该目录下的变更将会持久保存. 请及时清理不必要的文件, 避免加载过慢.\n",
    "# !ls /home/aistudio/work\n",
    "\n",
    "# 1.2 解压挂载的数据集在同级目录下\n",
    "!unzip -oq data/data107726/百科类问答json版.zip -d data/data107726/\n",
    "\n",
    "# 1.3 查看当前挂载的数据集目录, 该目录下的变更重启环境后会自动还原\n",
    "!ls /home/aistudio/data/data107726"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集数据量：181, 验证集数据量：28\n",
      "\n",
      "训练集数据样例：\n",
      " 0\t婆媳关系该怎样处理 \t少见面，见面时投其所好，少告状。\n"
     ]
    }
   ],
   "source": [
    "# 2. 处理数据格式\r\n",
    "# 根据模型要求，处理成：编号\\t问题\\t答案的形式\r\n",
    "# 由于数据量太大，内存不足会报错，限制了总文本的长度\r\n",
    "import json\r\n",
    "\r\n",
    "def resave_2txt(raw_path, raw_encod, save_path, save_encod, max_len=None):\r\n",
    "    try:\r\n",
    "        max_len = int(max_len)\r\n",
    "    except:\r\n",
    "        max_len = float('inf')\r\n",
    "    tmp_len = 0\r\n",
    "    save_txt = []\r\n",
    "    with open(raw_path,'r',encoding=raw_encod) as f:\r\n",
    "        for line in f.readlines():\r\n",
    "            dic_v = json.loads(line)\r\n",
    "            if '烦恼' in dic_v['category'].split('-'):\r\n",
    "                txt_line = str(tmp_len) + '\\t' + dic_v['title'].replace('\\n','').replace('\\r','')  + '\\t' + dic_v['answer'].replace('\\n','').replace('\\r','')\r\n",
    "                if len(txt_line)>30:\r\n",
    "                    pass\r\n",
    "                else:\r\n",
    "                    if tmp_len <= max_len:\r\n",
    "                        save_txt.append(txt_line)\r\n",
    "                        tmp_len += 1\r\n",
    "                        with open(save_path, 'a+', encoding=save_encod) as f:\r\n",
    "                            f.write(txt_line + '\\n')\r\n",
    "                    else:\r\n",
    "                        break\r\n",
    "    \r\n",
    "    return save_txt\r\n",
    "\r\n",
    "\r\n",
    "train_txt = resave_2txt('data/data107726/baike_qa_train.json', 'utf-8', 'data/data107726/baike_qa_train.txt', 'utf-8', 180)\r\n",
    "valid_txt = resave_2txt('data/data107726/baike_qa_valid.json', 'utf-8', 'data/data107726/baike_qa_valid.txt', 'utf-8', 30)\r\n",
    "print(f'训练集数据量：{len(train_txt)}, 验证集数据量：{len(valid_txt)}\\n')\r\n",
    "print('训练集数据样例：\\n', train_txt[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 模型训练\r\n",
    "由于对自己的要求是使用已有套件，初步实现目的。除了报错之外，对于参数暂时没有太多研究。\r\n",
    "\r\n",
    "训练时发现的报错有：\r\n",
    "1. 数据集的处理：一些问题回复本身含\\n，需要替换成其他字符，否则保存数据的时候，\\n之后的数据单独为一条记录，不符合要求的数据格式(编号\\t文本\\t文本)\r\n",
    "2. max_encode_len(int)最长编码长度 和  max_decode_len(int) 最长解码长度，受输入的数据长度影响，若数据长度较长，而此处设置的比较小，则会报错\r\n",
    "3. batch_size：有一些报错会提出需要调小batch_size\r\n"
   ]
  },
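  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Errors 1 and 2 above can be caught before training with a quick check of the converted file. The sketch below is a hypothetical helper (`check_tsv_lines` is not part of PaddleHub): it flags lines that do not have exactly three tab-separated fields and reports the longest question and answer in characters. Character counts are only a rough proxy for the tokenized length that `max_encode_len` and `max_decode_len` limit, and the sample lines are made up for illustration.\n",
    "\n",
    "```python\n",
    "def check_tsv_lines(lines):\n",
    "    # Returns (bad_line_numbers, max_question_len, max_answer_len).\n",
    "    bad, max_q, max_a = [], 0, 0\n",
    "    for n, line in enumerate(lines):\n",
    "        fields = line.rstrip('\\r\\n').split('\\t')\n",
    "        if len(fields) != 3:\n",
    "            # Text that followed a stray newline ends up here as a malformed line.\n",
    "            bad.append(n)\n",
    "            continue\n",
    "        max_q = max(max_q, len(fields[1]))\n",
    "        max_a = max(max_a, len(fields[2]))\n",
    "    return bad, max_q, max_a\n",
    "\n",
    "# Illustrative input: the second record was split by a newline inside its answer.\n",
    "sample = [\n",
    "    '0\\tfirst question\\tfirst answer\\n',\n",
    "    '1\\tsecond question\\tanswer part one\\n',\n",
    "    'answer part two\\n',\n",
    "]\n",
    "bad, max_q, max_a = check_tsv_lines(sample)\n",
    "print('malformed lines:', bad, 'longest question:', max_q, 'longest answer:', max_a)\n",
    "```\n",
    "\n",
    "In this notebook the same function could be applied to an open file handle for data/data107726/baike_qa_train.txt, since iterating over a text file yields its lines."
   ]
  },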
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Download https://bj.bcebos.com/paddlehub/paddlehub_dev/ernie_gen_1.1.0.tar.gz\n",
      "[##################################################] 100.00%\n",
      "Decompress /home/aistudio/.paddlehub/tmp/tmpc3541uc6/ernie_gen_1.1.0.tar.gz\n",
      "[##################################################] 100.00%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2022-03-02 16:55:18,192] [    INFO] - Successfully installed ernie_gen-1.1.0\n",
      "[2022-03-02 16:55:18,195] [    INFO] - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/ernie/vocab.txt and saved to /home/aistudio/.paddlenlp/models/ernie-1.0\n",
      "[2022-03-02 16:55:18,197] [    INFO] - Downloading vocab.txt from https://paddlenlp.bj.bcebos.com/models/transformers/ernie/vocab.txt\n",
      "100%|██████████| 90/90 [00:00<00:00, 2228.09it/s]\n",
      "[2022-03-02 16:55:18,334] [    INFO] - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams and saved to /home/aistudio/.paddlenlp/models/ernie-1.0\n",
      "100%|██████████| 392507/392507 [00:08<00:00, 44568.18it/s]\n",
      "[2022-03-02 16:55:27,219] [   DEBUG] - init ErnieModel with config: {'attention_probs_dropout_prob': 0.1, 'hidden_act': 'relu', 'hidden_dropout_prob': 0.1, 'hidden_size': 768, 'initializer_range': 0.02, 'max_position_embeddings': 513, 'num_attention_heads': 12, 'num_hidden_layers': 12, 'type_vocab_size': 2, 'vocab_size': 18000, 'pad_token_id': 0}\n",
      "W0302 16:55:27.223588   140 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n",
      "W0302 16:55:27.228751   140 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n",
      "[2022-03-02 16:55:31,836] [    INFO] - loading pretrained model from /home/aistudio/.paddlenlp/models/ernie-1.0/ernie_v1_chn_base.pdparams\n",
      "[2022-03-02 16:55:32,651] [    INFO] - param:mlm_bias not set in pretrained model, skip\n",
      "[2022-03-02 16:55:32,653] [    INFO] - param:mlm.weight not set in pretrained model, skip\n",
      "[2022-03-02 16:55:32,655] [    INFO] - param:mlm.bias not set in pretrained model, skip\n",
      "[2022-03-02 16:55:32,657] [    INFO] - param:mlm_ln.weight not set in pretrained model, skip\n",
      "[2022-03-02 16:55:32,659] [    INFO] - param:mlm_ln.bias not set in pretrained model, skip\n",
      "[2022-03-02 16:55:36,101] [    INFO] - [step 20 / 600]train loss 47.87453, ppl 618931878220596772864.00000, elr 1.667e-05\n",
      "[2022-03-02 16:55:38,964] [    INFO] - [step 40 / 600]train loss 18.81637, ppl 148540480.00000, elr 3.333e-05\n",
      "[2022-03-02 16:55:41,844] [    INFO] - [step 60 / 600]train loss 12.65310, ppl 312731.12500, elr 5.000e-05\n",
      "[2022-03-02 16:55:44,725] [    INFO] - [step 80 / 600]train loss 12.43987, ppl 252678.85938, elr 4.815e-05\n",
      "[2022-03-02 16:55:47,483] [    INFO] - [step 100 / 600]train loss 9.05701, ppl 8578.49316, elr 4.630e-05\n",
      "[2022-03-02 16:55:50,256] [    INFO] - [step 120 / 600]train loss 13.27389, ppl 581805.43750, elr 4.444e-05\n",
      "[2022-03-02 16:55:53,126] [    INFO] - [step 140 / 600]train loss 8.29847, ppl 4017.71094, elr 4.259e-05\n",
      "[2022-03-02 16:55:55,967] [    INFO] - [step 160 / 600]train loss 8.03882, ppl 3098.94019, elr 4.074e-05\n",
      "[2022-03-02 16:55:58,810] [    INFO] - [step 180 / 600]train loss 8.46049, ppl 4724.39600, elr 3.889e-05\n",
      "[2022-03-02 16:56:01,522] [    INFO] - [step 200 / 600]train loss 8.07915, ppl 3226.49609, elr 3.704e-05\n",
      "[2022-03-02 16:56:01,525] [    INFO] - save the model in ernie_gen_result/step_200_ppl_3226.49609.params\n",
      "[2022-03-02 16:56:02,794] [    INFO] - Evaluating...\n",
      "[2022-03-02 16:56:53,149] [    INFO] - Rouge-1: 2.55319 ,Rouge-2: 0.00000\n",
      "[2022-03-02 16:56:53,153] [   DEBUG] - ['好像是火花塞', '走唱歌去', '有但是太少了,还是要小心些.']\n",
      "[2022-03-02 16:56:53,154] [   DEBUG] - ['', '有有。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。》。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。', '有。。有。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。']\n",
      "[2022-03-02 16:56:56,007] [    INFO] - [step 220 / 600]train loss 6.53635, ppl 689.76257, elr 3.519e-05\n",
      "[2022-03-02 16:56:58,811] [    INFO] - [step 240 / 600]train loss 5.40688, ppl 222.93576, elr 3.333e-05\n",
      "[2022-03-02 16:57:01,565] [    INFO] - [step 260 / 600]train loss 6.47746, ppl 650.31976, elr 3.148e-05\n",
      "[2022-03-02 16:57:04,366] [    INFO] - [step 280 / 600]train loss 5.42990, ppl 228.12624, elr 2.963e-05\n",
      "[2022-03-02 16:57:07,295] [    INFO] - [step 300 / 600]train loss 5.87104, ppl 354.61871, elr 2.778e-05\n",
      "[2022-03-02 16:57:10,155] [    INFO] - [step 320 / 600]train loss 5.37416, ppl 215.75903, elr 2.593e-05\n",
      "[2022-03-02 16:57:13,038] [    INFO] - [step 340 / 600]train loss 6.00795, ppl 406.64951, elr 2.407e-05\n",
      "[2022-03-02 16:57:15,768] [    INFO] - [step 360 / 600]train loss 4.60535, ppl 100.01751, elr 2.222e-05\n",
      "[2022-03-02 16:57:18,383] [    INFO] - [step 380 / 600]train loss 6.76781, ppl 869.40442, elr 2.037e-05\n",
      "[2022-03-02 16:57:20,932] [    INFO] - [step 400 / 600]train loss 5.77120, ppl 320.92343, elr 1.852e-05\n",
      "[2022-03-02 16:57:20,935] [    INFO] - save the model in ernie_gen_result/step_400_ppl_320.92343.params\n",
      "[2022-03-02 16:57:22,198] [    INFO] - Evaluating...\n",
      "[2022-03-02 16:58:12,197] [    INFO] - Rouge-1: 3.82979 ,Rouge-2: 0.45872\n",
      "[2022-03-02 16:58:12,200] [   DEBUG] - ['好像是火花塞', '走唱歌去', '有但是太少了,还是要小心些.']\n",
      "[2022-03-02 16:58:12,202] [   DEBUG] - ['很不有的的的的！的................................................。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。', '很!很!!!!!!!!!!!!!!!!!!!!.!..!...!...!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!.!.!!.!!!!!!!!!!!', '需破没没的的的的的的！的！']\n",
      "[2022-03-02 16:58:14,896] [    INFO] - [step 420 / 600]train loss 2.81316, ppl 16.66254, elr 1.667e-05\n",
      "[2022-03-02 16:58:17,512] [    INFO] - [step 440 / 600]train loss 5.50030, ppl 244.76544, elr 1.481e-05\n",
      "[2022-03-02 16:58:20,068] [    INFO] - [step 460 / 600]train loss 4.56213, ppl 95.78728, elr 1.296e-05\n",
      "[2022-03-02 16:58:22,646] [    INFO] - [step 480 / 600]train loss 4.09717, ppl 60.17005, elr 1.111e-05\n",
      "[2022-03-02 16:58:25,237] [    INFO] - [step 500 / 600]train loss 4.15829, ppl 63.96193, elr 9.259e-06\n",
      "[2022-03-02 16:58:27,878] [    INFO] - [step 520 / 600]train loss 4.54431, ppl 94.09571, elr 7.407e-06\n",
      "[2022-03-02 16:58:30,501] [    INFO] - [step 540 / 600]train loss 5.04154, ppl 154.70753, elr 5.556e-06\n",
      "[2022-03-02 16:58:33,145] [    INFO] - [step 560 / 600]train loss 4.32917, ppl 75.88101, elr 3.704e-06\n",
      "[2022-03-02 16:58:35,748] [    INFO] - [step 580 / 600]train loss 4.70267, ppl 110.24066, elr 1.852e-06\n",
      "[2022-03-02 16:58:38,338] [    INFO] - [step 600 / 600]train loss 6.62177, ppl 751.27075, elr 0.000e+00\n",
      "[2022-03-02 16:58:38,341] [    INFO] - save the model in ernie_gen_result/step_600_ppl_751.27075.params\n",
      "[2022-03-02 16:58:39,597] [    INFO] - Evaluating...\n",
      "[2022-03-02 16:59:08,034] [    INFO] - Rouge-1: 4.25532 ,Rouge-2: 0.00000\n",
      "[2022-03-02 16:59:08,037] [   DEBUG] - ['好像是火花塞', '走唱歌去', '有但是太少了,还是要小心些.']\n",
      "[2022-03-02 16:59:08,039] [   DEBUG] - ['不是不有的的的。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。', '很!很!!!!!!!!!!!!!!', 'anncts 5300没没没挡的的']\n"
     ]
    }
   ],
   "source": [
    "# 3.训练模型\r\n",
    "# 暂时无脑先走流程，流程走完再调具体\r\n",
    "# 具体流程和参数可参考链接：https://aistudio.baidu.com/aistudio/projectdetail/878918?channelType=0&channel=0\r\n",
    "\r\n",
    "import paddlehub as hub\r\n",
    "module = hub.Module(name=\"ernie_gen\")\r\n",
    "\r\n",
    "result = module.finetune(\r\n",
    "    use_gpu = True,\r\n",
    "    train_path='data/data107726/baike_qa_train.txt',\r\n",
    "    dev_path='data/data107726/baike_qa_valid.txt',\r\n",
    "    save_dir=\"ernie_gen_result\",\r\n",
    "    max_steps=600,\r\n",
    "    max_encode_len=100,\r\n",
    "    max_decode_len=100,\r\n",
    "    noise_prob=0.2,\r\n",
    "    batch_size=2,\r\n",
    "    log_interval=20\r\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'last_save_path': 'ernie_gen_result/step_600_ppl_751.27075.params', 'last_ppl': 751.27075}\n"
     ]
    }
   ],
   "source": [
    "# 查看训练结束时的模型保存路径 和 模型困惑度\r\n",
    "print(result)"
   ]
  },
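  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "In the training log above, each reported `ppl` matches the exponential of the reported `train loss` (for example, loss 4.60535 and ppl 100.01751 at step 360). The short check below verifies this relation on two logged pairs. It is an observation about the log rather than a description of the module's internals, and `perplexity` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def perplexity(loss):\n",
    "    # Perplexity as the exponential of a cross-entropy loss measured in nats.\n",
    "    return math.exp(loss)\n",
    "\n",
    "# (train loss, logged ppl) pairs copied from steps 360 and 600 of the log above.\n",
    "logged = [(4.60535, 100.01751), (6.62177, 751.27075)]\n",
    "for loss, ppl in logged:\n",
    "    # The loss is printed with 5 decimals, so allow a small relative tolerance.\n",
    "    print(loss, ppl, math.isclose(perplexity(loss), ppl, rel_tol=1e-4))\n",
    "```\n",
    "\n",
    "Note that `last_ppl` is the perplexity of the last logged training step (step 600), not the lowest value seen during training; the log shows 16.66254 at step 420."
   ]
  },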
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2022-03-02 16:59:20,513] [    INFO] - Begin export the model save in ernie_gen_result/step_600_ppl_751.27075.params ...\n",
      "[2022-03-02 16:59:21,255] [    INFO] - The module has exported to /home/aistudio/ernie_gen_baikeqa\n"
     ]
    }
   ],
   "source": [
    "# 导出模型\r\n",
    "module.export(params_path=result['last_save_path'], module_name=\"ernie_gen_baikeqa\", author=\"zhuqianer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2022-03-02 17:00:03,060] [    INFO] - Successfully installed ernie_gen_baikeqa-1.0.0\n"
     ]
    }
   ],
   "source": [
    "# 安装module\r\n",
    "!hub install ernie_gen_baikeqa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['0\\t婆媳关系该怎样处理 \\t少见面，见面时投其所好，少告状。', '1\\t爱情的力量有多大？ \\t拥有了真爱就像拥有了全世界', '2\\t交朋友！想恋爱！！第一次怎么表达啊？ \\t自然点说出真心话', '3\\t爱情是什么感觉？爱一个人是什么感觉呢？ \\t时刻为TA着想', '4\\t你最爱的人是谁？ \\t是我自己', '5\\t学生爱上老师该怎么办 \\t想办法搞定她', '6\\t做爱时该想些什么呢？ \\t什么也不想想,舒服就好!', '7\\t军人婚姻 \\t军嫂都是伟大滴`~`~`~', '8\\t怎么区分是处男不是处男 \\t我路过的来赚分的', '9\\t谁是世界上最乖的人？MANORWOMAN \\tBABY']\n"
     ]
    }
   ],
   "source": [
    "print(train_txt[:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 模型预测\n",
    "\n",
    "1. 不同的问题返回的答案还是有区别的；\n",
    "\n",
    "2. 和训练集的答案可以说是完全无关了，但是还是挺好玩的\n",
    "\n",
    "3. 回复中有很多重复字和标点，应该是训练数据的问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2022-03-02 17:04:51,290] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-1.0/ernie_v1_chn_base.pdparams\n",
      "[2022-03-02 17:04:51,292] [   DEBUG] - init ErnieModel with config: {'attention_probs_dropout_prob': 0.1, 'hidden_act': 'relu', 'hidden_dropout_prob': 0.1, 'hidden_size': 768, 'initializer_range': 0.02, 'max_position_embeddings': 513, 'num_attention_heads': 12, 'num_hidden_layers': 12, 'type_vocab_size': 2, 'vocab_size': 18000, 'pad_token_id': 0}\n",
      "[2022-03-02 17:04:54,348] [    INFO] - loading pretrained model from /home/aistudio/.paddlenlp/models/ernie-1.0/ernie_v1_chn_base.pdparams\n",
      "[2022-03-02 17:04:55,304] [    INFO] - param:mlm_bias not set in pretrained model, skip\n",
      "[2022-03-02 17:04:55,306] [    INFO] - param:mlm.weight not set in pretrained model, skip\n",
      "[2022-03-02 17:04:55,308] [    INFO] - param:mlm.bias not set in pretrained model, skip\n",
      "[2022-03-02 17:04:55,309] [    INFO] - param:mlm_ln.weight not set in pretrained model, skip\n",
      "[2022-03-02 17:04:55,311] [    INFO] - param:mlm_ln.bias not set in pretrained model, skip\n",
      "[2022-03-02 17:04:56,863] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-1.0/vocab.txt\n",
      "[2022-03-02 17:04:56,878] [ WARNING] - use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['我方方方方方方方方方方方方方。方。。。。']\n",
      "['爱你你啊啊啊啊啊啊啊啊!!!!!!!!!!!!!!!.............!!!!!!!!!!!']\n"
     ]
    }
   ],
   "source": [
    "# 测试\r\n",
    "\r\n",
    "import paddlehub as hub\r\n",
    "\r\n",
    "module = hub.Module(name=\"ernie_gen_baikeqa\")\r\n",
    "\r\n",
    "test_texts = ['婆媳吵架怎么办','我帅吗']\r\n",
    "# generate包含3个参数，texts为输入文本列表，use_gpu指定是否使用gpu，beam_width指定beam search宽度。\r\n",
    "results = module.generate(texts=test_texts, use_gpu=True, beam_width=1)\r\n",
    "for result in results:\r\n",
    "    print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 个人总结\n",
    "1. 选模型之前一定要了解清楚模型的作用是什么。就像我原计划做一个严肃的可追溯的问答回复，结果选了一个文本生成模型，就变成了自己拿来刷着玩的回复。\n",
    "\n",
    "2. 接下来就是需要了解模型的输入输出、参数之类的，可以有针对性的进行参数调节（我目前只是一个初步实现，只了解了用到的参数，其余参数待探索）\n",
    "\n",
    "3. 了解模型背后的原理，可能就会知道为什么上面这个模型被训练成了一个容易激动（很多标点和重复字）的机器人了\n",
    "\n",
    "前路漫漫亦灿灿，加油！\n",
    "\n",
    "\n",
    "\n",
    "# 提交链接\n",
    "aistudio链接：https://aistudio.baidu.com/aistudio/projectdetail/3521598?contributionType=1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "请点击[此处](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576)查看本环境基本用法.  <br>\n",
    "Please click [here ](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576) for more detailed instructions. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
